import pandas as pd
import Bio.SeqIO
import numpy as np
import os
from keras.utils import np_utils
from keras import models
from keras import activations
from vis.visualization import visualize_saliency
from vis.utils import utils
from sys import argv

# Command-line configuration, e.g.:
#   DIVISION='subsample_assignment1' SUBSAMPLE='subsample2' PREDICTOR='Pro'
DIVISION, SUBSAMPLE, PREDICTOR = argv[1], argv[2], argv[3]

# Input data locations.
SEQ_FILE_PRO = 'promoters.fa'
SEQ_FILE_TER = 'terminators.fa'
DATA_FILE = 'un_vs_expr.csv'
TRAINED_MODELS_DIRECTORY = 'trained_models_BEM/'
SALIENCY_DIRECTORY = 'saliency/'

# Make sure the output directory for saliency maps exists.
os.makedirs(SALIENCY_DIRECTORY, exist_ok=True)

# Expression table indexed by gene id; provides 'category' and the
# train/test split assignment columns used later.
data = pd.read_csv(DATA_FILE, sep='\t').set_index('Geneid')

geneIDs = []
categories = []
seqs = []
seqs_pro = []
seqs_ter = []

# Walk the promoter and terminator FASTA files in lockstep; records are
# assumed to appear in the same gene order in both files — TODO confirm.
for rec_pro, rec_ter in zip(Bio.SeqIO.parse(SEQ_FILE_PRO, 'fasta'),
                            Bio.SeqIO.parse(SEQ_FILE_TER, 'fasta')):
    geneID = rec_pro.id
    if geneID not in data.index:
        continue  # keep only genes present in the expression table
    promoter = str(rec_pro.seq)
    terminator = str(rec_ter.seq)
    seqs.append(promoter + terminator)
    seqs_pro.append(promoter)
    seqs_ter.append(terminator)
    geneIDs.append(geneID)
    # Class 0 = 'expressed', class 1 = everything else (unexpressed).
    categories.append(0 if data['category'][geneID] == 'expressed' else 1)

geneIDs = np.array(geneIDs)
categories = np_utils.to_categorical(np.array(categories), 2)
         
# One-hot encoding: rows are A, C, G, T in that order.
# (Renamed from `dict`, which shadowed the builtin.)
NT_TO_ONE_HOT = {'A': [1, 0, 0, 0], 'C': [0, 1, 0, 0], 'G': [0, 0, 1, 0], 'T': [0, 0, 0, 1]}

def one_hot_encoding(seq):
    """One-hot encode a DNA sequence as a (4, len(seq)) float array.

    Rows correspond to A, C, G, T. Characters other than A/C/G/T
    (e.g. ambiguous bases such as 'N') are encoded as an all-zero
    column — consistent with the explicit zero-masking applied to the
    encoded array later in this script — instead of raising KeyError.
    """
    one_hot_encoded = np.zeros(shape=(4, len(seq)))
    zero_column = [0, 0, 0, 0]
    for i, nt in enumerate(seq):
        one_hot_encoded[:, i] = NT_TO_ONE_HOT.get(nt, zero_column)
    return one_hot_encoded

print("ONE-HOT ENCODING STARTED")
# Select which sequence set to encode based on the requested predictor.
# Previously three independent `if`s: a typo'd PREDICTOR left
# `one_hot_seqs` undefined and caused a confusing NameError below.
if PREDICTOR == 'Pro_and_Ter':
    selected_seqs = seqs
elif PREDICTOR == 'Pro':
    selected_seqs = seqs_pro
elif PREDICTOR == 'Ter':
    selected_seqs = seqs_ter
else:
    raise ValueError("unknown PREDICTOR %r (expected 'Pro', 'Ter' or 'Pro_and_Ter')" % PREDICTOR)
# Shape: (n_genes, 4, seq_len, 1) — trailing channel axis for the 2D CNN.
one_hot_seqs = np.expand_dims(np.array([one_hot_encoding(seq) for seq in selected_seqs], dtype=np.float32), 3)
# Zero out fixed positions — presumably masking the start/stop codons at
# the sequence boundaries (TODO confirm). For 'Pro'/'Ter' alone the
# second slice may fall outside the sequence and silently do nothing.
one_hot_seqs[:, :, 1000:1003, :] = 0
one_hot_seqs[:, :, 1997:2000, :] = 0
print("ONE-HOT ENCODING FINISHED")

# Per-model results written by the training script; one row per trained model.
un_vs_expr_summary = pd.read_csv(TRAINED_MODELS_DIRECTORY + 'SUMMARY_FILE', sep='\t', header=None)
un_vs_expr_summary.columns = ['name', 'splitting_id', 'testing_subsample', 'shuffle', 'tp', 'tn', 'fp', 'fn', 'accuracy', 'test_total', 'predictor']
# Rows matching the requested split/subsample/predictor, unshuffled models only.
_selected = (
    (un_vs_expr_summary['shuffle'] == 'None')
    & (un_vs_expr_summary['splitting_id'] == DIVISION)
    & (un_vs_expr_summary['testing_subsample'] == SUBSAMPLE)
    & (un_vs_expr_summary['predictor'] == PREDICTOR)
)
row_indices = un_vs_expr_summary.index[_selected].tolist()

for row_index in row_indices:

    # Metadata of this trained model run.
    name = un_vs_expr_summary['name'][row_index]
    splitting_id = un_vs_expr_summary['splitting_id'][row_index]
    testing_subsample = un_vs_expr_summary['testing_subsample'][row_index]
    predictor = un_vs_expr_summary['predictor'][row_index]

    # Held-out test genes for this split, and their rows in the encoded arrays.
    testing_genes = np.array([gene for gene in geneIDs if data[splitting_id][gene] == testing_subsample])
    testing_indices = np.array([np.where(geneIDs == gene)[0][0] for gene in testing_genes])
    testing_data = one_hot_seqs[testing_indices]
    testing_categories = np.argmax(categories[testing_indices], axis=1)

    # Load the trained model and predict BEFORE swapping the output
    # activation below — predictions must come from the original softmax.
    model = models.load_model(TRAINED_MODELS_DIRECTORY + "MODEL_" + str(name))
    prediction = np.argmax(model.predict(testing_data), axis=1)

    # Replace the final activation with linear, as keras-vis recommends
    # for gradient-based saliency, and rebuild the model.
    model.layers[-1].activation = activations.linear
    model = utils.apply_modifications(model)

    # Per-model results directory.
    foldername = SALIENCY_DIRECTORY + "%".join(["un_vs_expr", splitting_id, testing_subsample, predictor, str(name)])
    if not os.path.exists(foldername):
        os.makedirs(foldername)

    # Saliency map for each correctly classified gene (TP and TN only).
    # The tp/tn branches previously duplicated ~20 lines differing only
    # in the label and `filter_indices`; they are folded together here.
    for i, j in enumerate(testing_indices):
        testing_gene = geneIDs[j]
        if testing_categories[i] != prediction[i]:
            continue  # misclassified gene: skip entirely, as before
        class_index = int(prediction[i])  # also the saliency target class
        label = 'tp' if class_index == 1 else 'tn'
        filename = "%".join([str(name), splitting_id, testing_subsample, predictor, testing_gene, label])
        print(filename)
        out_path = "/".join([foldername, filename])
        # Skip maps already computed; recompute when a previous run left
        # an empty .npy file behind.
        if os.path.isfile(out_path + ".npy") and os.path.getsize(out_path + ".npy") > 0:
            print("already done!")
            continue
        grad = visualize_saliency(model, -1, filter_indices=class_index, seed_input=testing_data[i], backprop_modifier='guided', grad_modifier='relu')
        np.save(file=out_path, arr=grad)
